{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Content-based recommender using DSSM\n",
    "An example of how to build a Deep Structured Semantic Model (DSSM) for incorporating complex content-based features into a recommender system. See [Learning Deep Structured Semantic Models for Web Search using Clickthrough Data](https://www.microsoft.com/en-us/research/publication/learning-deep-structured-semantic-models-for-web-search-using-clickthrough-data/). This example does not attempt to provide a data source or train a model; it merely shows how to structure a complex DSSM network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import mxnet as mx\n",
    "import symbol_alexnet as alexnet\n",
    "import recotools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Size constants read by dssm_recommender\n",
    "max_user = 10 ** 6          # input_dim of the user-id Embedding\n",
    "title_vocab = 10 ** 5       # vocab_size of the title bag-of-words projection\n",
    "ngram_dimensions = 10 ** 8  # vocab_size of the query n-gram projection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def dssm_recommender(k):\n",
    "    \"\"\"Build the symbol graph of a two-stack DSSM recommender.\n",
    "\n",
    "    The content stack combines image features with a bag-of-words projection\n",
    "    of the title; the user stack combines a user-id embedding with a\n",
    "    bag-of-words projection of the query n-grams. Each stack ends in a\n",
    "    FullyConnected layer with `k` hidden units, and the two outputs are fed,\n",
    "    together with the label, to `recotools.CosineLoss`.\n",
    "\n",
    "    Reads the notebook-level constants `max_user`, `title_vocab` and\n",
    "    `ngram_dimensions`, so the constants cell must be run first.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    k : int\n",
    "        Size passed as `output_dim` / `num_hidden` to every projection,\n",
    "        embedding and FullyConnected layer built here.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The symbol returned by `recotools.CosineLoss`.\n",
    "    \"\"\"\n",
    "    sym = mx.symbol\n",
    "\n",
    "    # Named graph inputs\n",
    "    title_in = sym.Variable('title_words')\n",
    "    image_in = sym.Variable('image')\n",
    "    query_in = sym.Variable('query_ngrams')\n",
    "    user_in = sym.Variable('user_id')\n",
    "    label_in = sym.Variable('label')\n",
    "\n",
    "    # Content stack: image features + title words -> k hidden units.\n",
    "    # 256 is passed straight to alexnet.features; presumably the image\n",
    "    # feature size -- TODO confirm against symbol_alexnet.\n",
    "    image_feat = alexnet.features(image_in, 256)\n",
    "    title_proj = recotools.SparseBagOfWordProjection(\n",
    "        data=title_in, vocab_size=title_vocab, output_dim=k)\n",
    "    title_feat = sym.FullyConnected(data=title_proj, num_hidden=k)\n",
    "    content_cat = sym.Concat(image_feat, title_feat)\n",
    "    content_drop = sym.Dropout(content_cat, p=0.5)\n",
    "    content_vec = sym.FullyConnected(data=content_drop, num_hidden=k)\n",
    "\n",
    "    # User stack: user id + query n-grams -> k hidden units\n",
    "    user_emb = sym.Embedding(data=user_in, input_dim=max_user, output_dim=k)\n",
    "    user_feat = sym.FullyConnected(data=user_emb, num_hidden=k)\n",
    "    query_proj = recotools.SparseBagOfWordProjection(\n",
    "        data=query_in, vocab_size=ngram_dimensions, output_dim=k)\n",
    "    query_feat = sym.FullyConnected(data=query_proj, num_hidden=k)\n",
    "    user_cat = sym.Concat(user_feat, query_feat)\n",
    "    user_drop = sym.Dropout(user_cat, p=0.5)\n",
    "    user_vec = sym.FullyConnected(data=user_drop, num_hidden=k)\n",
    "\n",
    "    # Loss layer: compare the two stacks against the label\n",
    "    return recotools.CosineLoss(a=user_vec, b=content_vec, label=label_in)\n",
    "\n",
    "net1 = dssm_recommender(256)\n",
    "mx.viz.plot_network(net1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [Root]",
   "language": "python",
   "name": "Python [Root]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
